Draft

Aalen’s Dynamic Path Model

Causal Inference with Time Varying Effects in PyMC

path-models
sem
causal inference

hello world

import pymc as pm
import numpy as np 
import pandas as pd
import arviz as az
import pytensor.tensor as pt
from scipy.interpolate import BSpline

If you look to Odysseus on the morning the gates of Troy fell, he is well set up for a happy journey home. He is the architect of victory, his ships are loaded with spoils, and the wind is at his back. Yet, an odyssey can’t be completed in a single day, and conclusions drawn at the outset rarely survive journey’s end.

When we rely on static snapshots, like a single blood draw or a particular sales campaign

# Load the simulated start-stop survival data and keep only the columns used below.
df = pd.read_csv("aalen_simdata.csv")
df = df[['subject', 'x', 'dose', 'M', 'start', 'stop', 'event']]
df.head()
subject x dose M start stop event
0 1 0 ctrl 6.74 0 4.00 0
1 1 0 ctrl 6.91 4 8.00 0
2 1 0 ctrl 6.90 8 12.00 0
3 1 0 ctrl 6.71 12 26.00 0
4 1 0 ctrl 6.45 26 46.85 1
# Event rates and mediator levels summarised by treatment arm and dose group.
df.groupby(['x', 'dose'])[['event', 'M']].agg(['mean', 'sum'])
event M
mean sum mean sum
x dose
0 ctrl 0.164179 66 6.996915 2812.76
1 high 0.119205 54 8.081589 3660.96
low 0.139037 52 7.302620 2731.18
Code
import matplotlib.pyplot as plt
import pandas as pd

# Per-subject ordering info: treatment arm first, then total follow-up length
subject_info = (
    df.groupby('subject')
      .agg(x=('x', 'first'), max_stop=('stop', 'max'))
      .sort_values(['x', 'max_stop'])
)

subjects = subject_info.index.tolist()
subject_to_y = dict(zip(subjects, range(len(subjects))))

fig, ax = plt.subplots(figsize=(8, 0.1 * len(subjects)))

# One horizontal segment per at-risk interval; a red dot marks an observed event
for row in df.itertuples(index=False):
    y_pos = subject_to_y[row.subject]
    seg_color = 'tab:blue' if row.x == 1 else 'tab:orange'
    ax.hlines(
        y=y_pos,
        xmin=row.start,
        xmax=row.stop,
        color=seg_color,
        linewidth=3
    )
    if row.event == 1:
        ax.plot(
            row.stop,
            y_pos,
            marker='o',
            color='red',
            markersize=6,
            zorder=3
        )

# Axis formatting
ax.set_yticks(range(len(subjects)))
ax.set_yticklabels(subjects)
ax.set_xlabel("Time")
ax.set_ylabel("Subject")

# Dashed line visually separates the two treatment groups
n_untreated = (subject_info['x'] == 0).sum()
ax.axhline(n_untreated - 0.5, color='black', linestyle='--', linewidth=1)

# Legend
from matplotlib.lines import Line2D

legend_elements = [
    Line2D([0], [0], color='tab:blue', lw=3, label='x = 1'),
    Line2D([0], [0], color='tab:orange', lw=3, label='x = 0'),
    Line2D([0], [0], marker='o', color='red', lw=0, label='Event', markersize=6),
]

ax.legend(handles=legend_elements, loc='upper right')

ax.set_title("Subject Timelines Ordered by Treatment Level")

plt.tight_layout()
plt.show()

Data Preparation

def prepare_aalen_dpa_data(
    df,
    subject_col="subject",
    start_col="start",
    stop_col="stop",
    event_col="event",
    x_col="x",
    m_col="M",
    dose_col="dose",
):
    """
    Prepare Andersen–Gill / Aalen dynamic path data for PyMC.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format start–stop survival data
    subject_col : str
        Subject identifier
    start_col, stop_col : str
        Interval boundaries
    event_col : str
        Event indicator (0/1)
    x_col : str
        Exposure / treatment
    m_col : str
        Mediator measured at interval start
    dose_col : str
        Dose-group column; the levels "low" and "high" are used to
        build the I_low / I_high indicator columns

    Returns
    -------
    dict
        'bins'    : DataFrame of the unique (start, stop) intervals with
                    their time-ordered 'bin_idx'
        'df_long' : augmented long-format frame, sorted by subject and start
        plus numpy arrays aligned with 'df_long':
        'N' (event counts), 'Y' (at-risk indicator), 'bin_idx', 'x', 'm',
        'm_lag', 'dt', and the scalar 'n_bins'

    Raises
    ------
    ValueError
        If any interval has non-positive length.
    """

    df = df.copy()

    # -------------------------------------------------
    # 1. Interval lengths
    # -------------------------------------------------
    df["dt"] = df[stop_col] - df[start_col]

    if (df["dt"] <= 0).any():
        raise ValueError("Non-positive interval lengths detected.")

    # -------------------------------------------------
    # 2. Time-bin indexing (piecewise-constant effects):
    #    one bin per distinct (start, stop) interval, ordered in time
    # -------------------------------------------------
    bins = (
        df[[start_col, stop_col]]
        .drop_duplicates()
        .sort_values([start_col, stop_col])
        .reset_index(drop=True)
    )
    bins["bin_idx"] = np.arange(len(bins))

    df = df.merge(
        bins,
        on=[start_col, stop_col],
        how="left",
        validate="many_to_one"
    )
    n_bins = bins.shape[0]

    # -------------------------------------------------
    # 3. Covariates: the mediator is mean-centered (important for Aalen
    #    models); the binary treatment is kept on its original scale
    # -------------------------------------------------
    df["x_c"] = df[x_col]
    df["m_c"] = df[m_col] - df[m_col].mean()

    # -------------------------------------------------
    # 4. Predictable mediator (lag within subject): the mediator value
    #    carried into each interval, 0.0 for a subject's first interval
    # -------------------------------------------------
    df = df.sort_values([subject_col, start_col])

    df["m_lag"] = (
        df.groupby(subject_col)["m_c"]
          .shift(1)
          .fillna(0.0)
    )

    # Dose-group indicators (previously hard-coded to the "dose" column)
    df["I_low"]  = (df[dose_col] == "low").astype(int)
    df["I_high"] = (df[dose_col] == "high").astype(int)

    # -------------------------------------------------
    # 5. Assemble output — all arrays are extracted AFTER the sort so
    #    they stay aligned with 'df_long'
    # -------------------------------------------------
    data = {
        "bins": bins,         # useful for plotting
        "df_long": df,        # long-format frame for inspection / modelling
        "N": df[event_col].astype(int).values,
        "Y": np.ones(len(df), dtype=int),   # Andersen–Gill at-risk indicator
        "bin_idx": df["bin_idx"].values,
        "n_bins": n_bins,
        "x": df["x_c"].values,
        "m": df["m_c"].values,
        "m_lag": df["m_lag"].values,
        "dt": df["dt"].values,
    }

    return data
# Run the data-preparation pipeline and inspect the long-format rows.
data = prepare_aalen_dpa_data(df)
df_long = data['df_long']
df_long[['subject', 'x', 'dose', 'M', 'event', 'dt', 'bin_idx']].head(14)
subject x dose M event dt bin_idx
0 1 0 ctrl 6.74 0 4.00 7
1 1 0 ctrl 6.91 0 4.00 13
2 1 0 ctrl 6.90 0 4.00 23
3 1 0 ctrl 6.71 0 14.00 53
4 1 0 ctrl 6.45 1 20.85 81
5 2 1 high 6.11 0 4.00 7
6 2 1 high 6.28 0 4.00 13
7 2 1 high 7.04 0 4.00 23
8 2 1 high 6.93 0 14.00 53
9 2 1 high 7.86 0 26.00 89
10 2 1 high 8.47 0 26.00 115
11 2 1 high 8.91 0 26.00 137
12 2 1 high 8.99 0 52.00 162
13 2 1 high 9.36 0 104.00 188
def create_bspline_basis(n_bins, n_knots=10, degree=3):
    """
    Create B-spline basis functions for smooth time-varying effects.

    Parameters
    ----------
    n_bins : int
        Number of time bins
    n_knots : int
        Number of internal knots (fewer = smoother)
    degree : int
        Degree of spline (3 = cubic, recommended)

    Returns
    -------
    basis : np.ndarray
        Matrix of shape (n_bins, n_basis) with basis function values,
        where n_basis = n_knots + degree - 1
    """
    # Internal knots equally spaced across the time range [0, n_bins - 1]
    internal_knots = np.linspace(0, n_bins - 1, n_knots)

    # Repeat the boundary knots so each endpoint appears degree+1 times
    # in total (clamped spline)
    knots = np.concatenate([
        np.repeat(internal_knots[0], degree),
        internal_knots,
        np.repeat(internal_knots[-1], degree)
    ])

    n_basis = len(knots) - degree - 1
    t = np.arange(n_bins, dtype=float)

    # Evaluate all basis functions in one vectorized call: an identity
    # coefficient matrix makes column i the i-th basis function.
    # (The original looped over basis functions, building one BSpline each.)
    basis = BSpline(knots, np.eye(n_basis), degree, extrapolate=False)(t)

    # extrapolate=False yields NaN for out-of-range inputs; zero them so
    # downstream dot products stay finite.
    return np.nan_to_num(basis)

# Build the default basis over the observed time bins and preview it.
n_knots = 10
n_bins = data['bins'].shape[0]
basis = create_bspline_basis(n_bins, n_knots=n_knots, degree=3)
n_cols = basis.shape[1]
basis_df = pd.DataFrame(basis, columns=[f'feature_{i}' for i in range(n_cols)])
basis_df.head(10)
feature_0 feature_1 feature_2 feature_3 feature_4 feature_5 feature_6 feature_7 feature_8 feature_9 feature_10 feature_11
0 1.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1 0.863149 0.133496 0.003337 0.000018 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2 0.739389 0.247518 0.012946 0.000146 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
3 0.628064 0.343219 0.028223 0.000494 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
4 0.528515 0.421749 0.048566 0.001170 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
5 0.440083 0.484261 0.073370 0.002286 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
6 0.362110 0.531908 0.102032 0.003950 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
7 0.293939 0.565840 0.133949 0.006272 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
8 0.234909 0.587211 0.168518 0.009362 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
9 0.184365 0.597171 0.205134 0.013330 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
def make_model(data, basis, sample=True, observed=True):
    """
    Build (and optionally fit) the Aalen-style dynamic path model in PyMC.

    Two sub-models share a common B-spline time basis:
    a Gaussian mediator model (A path: x -> M) with a lagged-mediator term,
    and an additive hazard model for events (direct effect and B path:
    M -> event), fitted as a Poisson likelihood on the long-format
    counting-process rows.

    Parameters
    ----------
    data : dict
        Output of `prepare_aalen_dpa_data`; must contain 'df_long' and 'bins'.
    basis : np.ndarray
        B-spline basis matrix of shape (n_bins, n_basis).
    sample : bool
        If True, draw prior predictive, posterior and posterior predictive
        samples. If False, the returned idata is None.
    observed : bool
        If True, condition the event likelihood on the observed events.

    Returns
    -------
    (pm.Model, arviz.InferenceData or None)
    """
    df_long = data['df_long'].copy()
    n_basis = basis.shape[1]
    n_obs = data['df_long'].shape[0]
    time_bins = data['bins']['bin_idx'].values
    # Positional bin index per observation, as a plain ndarray for indexing
    b = df_long['bin_idx'].values

    observed_mediator = df_long["m_c"].values
    observed_events = df_long['event'].astype(int).values
    observed_treatment = df_long['x'].astype(int).values
    observed_mediator_lag = df_long['m_lag'].values

    # BUG FIX: the spline labels were built from a plain string
    # ('spline_f_{i}' literally, repeated n_basis times), so every
    # coordinate value was identical; the f-string gives unique labels.
    coords = {'tv': ['intercept', 'direct', 'mediator'],
              'splines': [f'spline_{i}' for i in range(n_basis)],
              'obs': range(n_obs),
              'time_bins': time_bins}

    with pm.Model(coords=coords) as aalen_dpa_model:

        trt = pm.Data("trt", observed_treatment, dims="obs")
        med = pm.Data("mediator", observed_mediator, dims="obs")
        med_lag = pm.Data("mediator_lag", observed_mediator_lag, dims="obs")
        events = pm.Data("events", observed_events, dims="obs")
        I_low  = pm.Data("I_low",  df_long["I_low"].values,  dims="obs")
        I_high = pm.Data("I_high", df_long["I_high"].values, dims="obs")
        dt = pm.Data("duration", df_long['dt'].values, dims='obs')
        ## because our long data format has a cell per obs
        at_risk = pm.Data("at_risk", np.ones(len(observed_events)), dims="obs")
        basis_ = pm.Data("basis", basis, dims=('time_bins', 'splines'))

        # -------------------------------------------------
        # 1. B-spline coefficients for HAZARD model
        # -------------------------------------------------
        # Random Walk 1 (RW1) prior on spline coefficients:
        # smaller sigma_smooth = less wiggliness. This is the Bayesian
        # analogue of the smoothing penalty in R's 'mgcv' or 'timereg'.
        sigma_smooth = pm.Exponential("sigma_smooth", [1, 1, 1], dims='tv')
        beta_raw = pm.Normal("beta_raw", 0, 1, dims=('splines', 'tv'))

        # Cumulative sum turns iid increments into a random walk, so the
        # coefficients (and hence the curves) evolve smoothly over time.
        coef_alpha = pm.Deterministic("coef_alpha", pt.cumsum(beta_raw * sigma_smooth, axis=0), dims=('splines', 'tv'))

        # Smooth time-varying functions: intercept, direct, mediator
        alpha_0_t = pt.dot(basis_, coef_alpha[:, 0])
        alpha_1_t = pt.dot(basis_, coef_alpha[:, 1])
        alpha_2_t = pt.dot(basis_, coef_alpha[:, 2])

        # -------------------------------------------------
        # 2. B-spline coefficients for MEDIATOR model
        # -------------------------------------------------
        sigma_beta_smooth = pm.Exponential("sigma_beta_smooth", 0.1)
        # Renamed local (the original reused `beta_raw`, shadowing the
        # hazard-model coefficients above).
        beta_raw_m = pm.Normal("beta_raw_m", 0, 1, dims=('splines'))
        coef_beta = pt.cumsum(beta_raw_m * sigma_beta_smooth)

        beta_t = pt.dot(basis_, coef_beta)

        # -------------------------------------------------
        # 3. Mediator model (A path: x -> M)
        # -------------------------------------------------
        sigma_m = pm.HalfNormal("sigma_m", 1.0)

        # Autoregressive component on the lagged mediator
        rho = pm.Beta("rho", 2, 2)

        mu_m = beta_t[b] * trt + rho * med_lag

        pm.Normal(
            "obs_m",
            mu=mu_m,
            sigma=sigma_m,
            observed=med,
            dims='obs'
        )

        # -------------------------------------------------
        # 4. Hazard model (direct + B path)
        # -------------------------------------------------
        beta_low  = pm.Normal("beta_low",  0, 0.1)
        beta_high = pm.Normal("beta_high", 0, 0.1)
        # Additive predictor on the (pre-softplus) hazard scale
        log_lambda_t = (alpha_0_t[b]
                        + alpha_1_t[b] * trt  # direct effect
                        + alpha_2_t[b] * med  # mediator effect
                        + beta_low  * I_low
                        + beta_high * I_high
        )

        # Expected number of events: exposure time x softplus(hazard)
        time_at_risk = at_risk * dt
        Lambda = time_at_risk * pm.math.log1pexp(log_lambda_t)

        if observed:
            pm.Poisson(
                "obs_event",
                mu=Lambda,
                observed=events,
                dims='obs'
            )
        else:
            pm.Poisson(
                "obs_event",
                mu=Lambda,
                dims='obs'
            )

        # -------------------------------------------------
        # 5. Causal path effects
        # -------------------------------------------------
        # Store time-varying coefficients
        pm.Deterministic("alpha_0_t", alpha_0_t, dims='time_bins')
        pm.Deterministic("alpha_1_t", alpha_1_t, dims='time_bins')  # direct effect
        pm.Deterministic("alpha_2_t", alpha_2_t, dims='time_bins')  # B path
        pm.Deterministic("beta_t", beta_t, dims='time_bins')        # A path

        # Time-varying direct effect
        cum_de = pm.Deterministic(
            "tv_direct_effect",
            alpha_1_t,
            dims='time_bins'
        )

        # Time-varying indirect effect (product of paths)
        cum_ie = pm.Deterministic(
            "tv_indirect_effect",
            beta_t * alpha_2_t,
            dims='time_bins'
        )

        # Total effect
        cum_te = pm.Deterministic(
            "tv_total_effect",
            cum_de + cum_ie,
            dims='time_bins'
        )

        pm.Deterministic('tv_baseline_hazard', pm.math.log1pexp(alpha_0_t),
            dims='time_bins')

        pm.Deterministic('tv_hazard_with_exposure', pm.math.log1pexp(alpha_0_t + alpha_1_t),
            dims='time_bins')

        pm.Deterministic(
            "tv_RR",
            pm.math.log1pexp(alpha_0_t + alpha_1_t) /
            pm.math.log1pexp(alpha_0_t),
            dims="time_bins"
        )

        # -------------------------------------------------
        # 6. Sample
        # -------------------------------------------------
        # BUG FIX: `idata` was unbound when sample=False, so the return
        # statement raised NameError; default it to None.
        idata = None
        if sample:
            idata = pm.sample_prior_predictive()
            idata.extend(pm.sample(
                draws=2000,
                tune=2000,
                target_accept=0.95,
                chains=4,
                nuts_sampler="numpyro",
                random_seed=42,
                init="adapt_diag",
                idata_kwargs={"log_likelihood": True}
            ))
            idata.extend(pm.sample_posterior_predictive(idata))

    return aalen_dpa_model, idata

# Fit the headline model with a 12-knot basis and render the model graph.
basis = create_bspline_basis(n_bins, n_knots=12, degree=3)
aalen_dpa_model, idata_aalen =  make_model(data, basis)
pm.model_to_graphviz(aalen_dpa_model)

# Sensitivity analysis: refit with increasingly flexible spline bases
models = {}
idatas = {}
for i in range(4, 15, 2):
    basis = create_bspline_basis(n_bins, n_knots=i, degree=3)
    aalen_dpa_model, idata = make_model(data, basis)
    models[i] = aalen_dpa_model
    idatas[f"splines_{i}"] = idata

# Compare the fitted models on the event likelihood
compare_df = az.compare(idatas, var_name='obs_event')
az.plot_compare(compare_df, figsize=(8, 6), plot_ic_diff=True)

# BUG FIX: the time_bins selection listed 182 twice and skipped 181;
# use the nine consecutive final intervals 180..188 instead.
ax = az.plot_forest(
    list(idatas.values()),
    combined=True,
    var_names=['tv_direct_effect'],
    model_names=list(idatas.keys()),
    coords={'time_bins': list(range(180, 189))},
    figsize=(12, 10),
    r_hat=True,
)
ax[0].set_title("Time Vary Direct Effects \n Comparing Models on Final Time Intervals", fontsize=15)
ax[0].set_ylabel("Nth Time Interval", fontsize=15)
fig = ax[0].figure
fig.savefig('forest_plot_comparing_tv_direct.png')

# Trace diagnostics for the headline effect estimates and dose coefficients.
az.plot_trace(idata_aalen, var_names=['tv_direct_effect', 'tv_indirect_effect', 'tv_total_effect', 'beta_high', 'beta_low'], divergences=False);
plt.tight_layout()

# Effect curves to visualise and their panel titles.
vars_to_plot = ['tv_direct_effect', 'tv_indirect_effect', 'tv_total_effect']
labels = ['Time varying Direct Effect', 'Time varying Indirect Effect', 'Time varying Total Effect']

def plot_effects(idata, vars_to_plot, labels, scale="Log Hazard Ratio Scale"):
    """
    Plot the posterior mean and 94% HDI band for time-varying effect curves.

    Parameters
    ----------
    idata : arviz.InferenceData
        Fitted model output containing the named posterior variables.
    vars_to_plot : list of str
        Posterior variable names, each indexed by the time dimension.
    labels : list of str
        Panel titles / legend labels, one per variable.
    scale : str
        Y-axis label; any value other than the default switches the
        line colour from teal to dark red.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, axs = plt.subplots(1, 3, figsize=(20, 10))
    color='teal'
    if scale != "Log Hazard Ratio Scale":
        color='darkred'

    for i, var in enumerate(vars_to_plot):
        # 1. Extract the posterior samples for this variable
        # Shape will be (chain * draw, time)
        post_samples = az.extract(idata, var_names=[var]).values.T

        # 2. Mean and 94% HDI across the pooled draws.
        mean_val = post_samples.mean(axis=0)
        # BUG FIX: az.hdi on a 2d array emits a FutureWarning about the
        # (draw, shape) -> (chain, draw) interpretation change; adding an
        # explicit leading chain axis gives the same result without it.
        hdi_val = az.hdi(post_samples[np.newaxis, ...], hdi_prob=0.94)  # [time, 2]

        # 3. Plot the mean line
        x_axis = np.arange(len(mean_val))
        axs[i].plot(x_axis, mean_val, label=labels[i], color=color, lw=2)

        # 4. Plot the shaded HDI region
        axs[i].fill_between(x_axis, hdi_val[:, 0], hdi_val[:, 1], color=color, alpha=0.2, label='94% HDI')

        # Formatting
        axs[i].set_title(labels[i])
        axs[i].legend()
        axs[i].grid(alpha=0.3)
        axs[i].set_ylabel(scale)
    plt.tight_layout()
    return fig

# Posterior mean and 94% HDI bands for the mediation effect curves.
plot_effects(idata_aalen, vars_to_plot, labels);
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_13629/3910338117.py:17: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_13629/3910338117.py:17: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_13629/3910338117.py:17: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array

# Hazard-scale summaries: baseline hazard, hazard under exposure, and their ratio.
vars_to_plot = ['tv_baseline_hazard', 'tv_hazard_with_exposure', 'tv_RR']
labels = ['Time varying Baseline Hazard', 'Time varying Hazard + Exposure', 'Time varying RR']
plot_effects(idata_aalen, vars_to_plot, labels, scale='Hazard Scale');
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_13629/3910338117.py:17: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_13629/3910338117.py:17: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_13629/3910338117.py:17: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array

Citation

BibTeX citation:
@online{forde,
  author = {Forde, Nathaniel},
  title = {Aalen’s {Dynamic} {Path} {Model}},
  langid = {en}
}
For attribution, please cite this work as:
Forde, Nathaniel. n.d. “Aalen’s Dynamic Path Model.”